import os
import argparse
import logging
import sys
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.neighbors import KNeighborsClassifier
import pickle
from matplotlib.colors import LinearSegmentedColormap
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import copy
from tqdm import tqdm
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

# Import SpikingJelly from local folder
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'spikingjelly'))
from spikingjelly.activation_based import functional

from dataset.dataloaders import get_dataset
from models.spiking_cnn_models import SpikingCNNModel_Pool
from utils.loss import C2Loss_Classification_CNN_Test, gradient_centralization

# Module-level command-line parser for the training script.
# No arguments are registered in this part of the file; presumably they are
# added further down — TODO confirm.
parser = argparse.ArgumentParser(description='Spiking CNN BSD Training')

def add_gaussian_noise(data, mean=0.0, std=1.0, normalize_mean=None, normalize_std=None):
    """
    Add Gaussian noise to input data, with proper handling of normalization.

    When both normalization statistics are given, the noise is applied in the
    raw pixel domain:
    1. Denormalize data (data * std + mean per channel)
    2. Add noise to the raw data
    3. Clip to the valid range [0, 1]
    4. Re-normalize

    When either statistic is missing, the noise is added directly to ``data``
    and no clipping is performed.

    Args:
        data: Input tensor [N, C, H, W] (may be normalized)
        mean: Mean of Gaussian noise (default: 0.0)
        std: Standard deviation of Gaussian noise (default: 1.0)
        normalize_mean: Per-channel mean used for normalization [C]
            (sequence or tensor), or None
        normalize_std: Per-channel std used for normalization [C]
            (sequence or tensor), or None

    Returns:
        noisy_data: Tensor with the same shape, dtype and device as ``data``
    """
    if normalize_mean is None or normalize_std is None:
        # Data is not normalized (or stats are unknown): add noise directly.
        return data + (torch.randn_like(data) * std + mean)

    # as_tensor also moves/casts statistics that already are tensors, so stats
    # kept on another device or in another dtype cannot raise a device error
    # or silently promote the result dtype.
    channel_mean = torch.as_tensor(
        normalize_mean, device=data.device, dtype=data.dtype
    ).view(1, -1, 1, 1)
    channel_std = torch.as_tensor(
        normalize_std, device=data.device, dtype=data.dtype
    ).view(1, -1, 1, 1)

    # Step 1: back to the raw [0, 1] range
    raw = data * channel_std + channel_mean

    # Step 2: add noise in the raw domain
    noisy = raw + (torch.randn_like(raw) * std + mean)

    # Step 3: clip so the noise cannot produce out-of-range pixel values
    noisy = torch.clamp(noisy, 0.0, 1.0)

    # Step 4: re-normalize
    return (noisy - channel_mean) / channel_std


def initialize_model_weights(model, weight_init='default', init_mean=0.0, init_std=0.02):
    """Re-initialize the Linear/Conv2d weights of ``model`` in place.

    Args:
        model: Module whose sub-modules are visited via ``model.apply``.
        weight_init: 'default' leaves the model untouched; 'normal', 'xavier'
            and 'kaiming' select the corresponding ``nn.init`` scheme. Any
            other value only zeroes the biases.
        init_mean: Mean for the 'normal' scheme.
        init_std: Standard deviation for the 'normal' scheme.
    """
    # 'default' keeps PyTorch's native initialization, biases included.
    if weight_init == 'default':
        return

    def _reinit(module):
        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        if weight_init == 'normal':
            # Normal (Gaussian) initialization
            nn.init.normal_(module.weight, mean=init_mean, std=init_std)
        elif weight_init == 'xavier':
            # Xavier/Glorot initialization
            nn.init.xavier_uniform_(module.weight)
        elif weight_init == 'kaiming':
            # Kaiming/He initialization
            nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')

        # Biases are always reset to zero when present
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    model.apply(_reinit)
    print(f"🔧 Model weights initialized using {weight_init} method")
    if weight_init == 'normal':
        print(f"   Normal distribution: mean={init_mean}, std={init_std}")


def compute_gradient_norm(parameters):
    """Return the global L2 gradient norm over ``parameters``.

    Args:
        parameters: Iterable of objects exposing a ``.grad`` tensor (or None).

    Returns:
        (norm, count): the L2 norm over all available gradients and the number
        of parameters that had a gradient; ``(0.0, 0)`` when none did.
    """
    squared_sum = 0.0
    grads_seen = 0

    for param in parameters:
        grad = param.grad
        if grad is None:
            continue
        # Accumulate squared per-parameter norms; the root is taken once below.
        squared_sum += grad.data.norm(2).item() ** 2
        grads_seen += 1

    return (squared_sum ** 0.5, grads_seen) if grads_seen else (0.0, 0)


def compute_weight_alignment(model):
    """Compute cosine similarity between forward and backward network weights.

    For every candidate layer that exposes both ``forward_layer`` and
    ``backward_layer``, the backward weight has its last two dimensions
    swapped, both weights are flattened, and their cosine similarity is taken.

    Args:
        model: Object whose layers are looked up by attribute name.

    Returns:
        dict mapping layer name -> cosine similarity (float).
    """
    # The candidate layer names depend on which architecture the model exposes.
    if hasattr(model, 'enc5'):
        candidates = ('enc1', 'enc2', 'enc3', 'enc4', 'enc5', 'output')
    elif hasattr(model, 'enc3') and hasattr(model, 'fc'):
        candidates = ('enc1', 'enc2', 'enc3', 'fc', 'output')
    else:
        candidates = ('enc1', 'enc2', 'fc', 'output')

    similarities = {}
    for name in candidates:
        layer = getattr(model, name, None)
        if not (hasattr(layer, 'forward_layer') and hasattr(layer, 'backward_layer')):
            continue

        forward_flat = layer.forward_layer.weight.data.flatten()
        # NOTE(review): for 4-D conv kernels this swaps the two spatial axes,
        # not the channel axes — confirm that is the intended pairing.
        backward_flat = layer.backward_layer.weight.data.transpose(-2, -1).flatten()

        forward_norm = forward_flat.norm()
        backward_norm = backward_flat.norm()
        dot = (forward_flat * backward_flat).sum()

        # The small epsilon guards against division by zero for all-zero weights.
        similarities[name] = (dot / (forward_norm * backward_norm + 1e-8)).item()

    return similarities


def compute_spike_alignment(activations_spikes, signals_spikes):
    """Compute Hamming similarity of LIF output spikes (CNN version) - take best sample.

    The two sequences are walked in lockstep. The first entry and the entry at
    the last index of ``activations_spikes`` are skipped (input and
    output/target do not go through LIF), as are pairs that are not both
    tensors. For the remaining pairs the fraction of equal spikes is computed
    per sample and the best sample's value is stored.

    Args:
        activations_spikes: Sequence of forward spike tensors, laid out as
            [T, N, ...] (time, batch, features...).
        signals_spikes: Sequence of backward spike tensors, same layout.

    Returns:
        dict mapping ``'cnn_layer_{i}'`` -> best per-sample Hamming similarity.

    Raises:
        ValueError: If a compared tensor contains values other than 0 and 1.

    Uses the module-level ``logger`` for shape diagnostics.
    """
    def _squeeze_trailing(tensor):
        # Drop size-1 dimensions after the leading [T, N] axes only. A bare
        # squeeze() would also drop a time or batch axis of size 1 (e.g. a
        # final batch with one sample) and shift which axis is the batch axis.
        if tensor.dim() <= 2:
            return tensor
        kept = [size for size in tensor.shape[2:] if size != 1]
        return tensor.reshape(*tensor.shape[:2], *kept)

    alignments = {}
    last_index = len(activations_spikes) - 1

    for i, (act, sig) in enumerate(zip(activations_spikes, signals_spikes)):
        # Skip first layer (input) and last layer (output/target)
        if i == 0 or i == last_index:
            continue
        if not (isinstance(act, torch.Tensor) and isinstance(sig, torch.Tensor)):
            continue

        # Handle shape mismatch caused by extra dimensions of size 1
        act_squeezed = _squeeze_trailing(act)
        sig_squeezed = _squeeze_trailing(sig)
        if act_squeezed.shape != sig_squeezed.shape:
            # Fall back to a full squeeze for layouts where a leading size-1
            # axis exists on only one side.
            act_squeezed, sig_squeezed = act.squeeze(), sig.squeeze()

        if act_squeezed.shape != sig_squeezed.shape:
            logger.warning(f"Layer {i}: Shape mismatch after squeeze - forward: {act_squeezed.shape}, backward: {sig_squeezed.shape}. Skipping spike alignment for this layer.")
            continue

        act, sig = act_squeezed, sig_squeezed
        logger.info(f"Layer {i}: Using shapes - forward: {act.shape}, backward: {sig.shape}")

        # Verify LIF output is 0-1 spikes. The cheap elementwise check runs on
        # every call; torch.unique is only needed to build the error message.
        act_is_binary = bool(((act == 0) | (act == 1)).all())
        sig_is_binary = bool(((sig == 0) | (sig == 1)).all())
        if not (act_is_binary and sig_is_binary):
            act_unique = torch.unique(act)
            sig_unique = torch.unique(sig)
            raise ValueError(f"Layer {i}: LIF outputs are not 0-1 spikes! Act values: {act_unique.tolist()}, Sig values: {sig_unique.tolist()}")

        # Per-sample similarity: average the matches over every axis except
        # the batch axis (dim 1), then keep the best sample.
        matches = (act == sig).float()  # [T, N, ...]
        reduce_dims = [d for d in range(matches.dim()) if d != 1]
        per_sample = matches.mean(dim=reduce_dims)  # [N]

        alignments[f'cnn_layer_{i}'] = per_sample.max().item()

    return alignments


def plot_weight_alignment(weight_data, save_path):
    """Draw one line per layer of weight-alignment values and save the figure.

    Args:
        weight_data: Mapping layer name -> sequence of cosine similarities,
            plotted against the module-level ``epoch_list``.
        save_path: Destination path of the saved figure.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Thick lines without markers, one per layer
    for name, series in weight_data.items():
        ax.plot(epoch_list, series, linewidth=3, label=name)

    ax.set_xlabel('Epoch', fontsize=14)
    ax.set_ylabel('Cosine Similarity', fontsize=14)
    ax.set_title('Weight Alignment Between Forward and Backward Networks', fontsize=16, fontweight='bold')
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    logger.info(f"Weight alignment plot saved to {save_path}")


def plot_spike_alignment(spike_data, save_path):
    """Plot the spike-alignment heatmap (layers x epochs; pink = low, blue = high).

    Args:
        spike_data: Mapping layer name -> sequence of Hamming similarities,
            one value per entry of the module-level ``spike_epoch_list``.
        save_path: Destination path of the saved figure.

    The colour scale normally starts at 0.8 so that differences among high
    similarities stay visible. When every value is at or below 0.8, the data
    minimum is used instead, because a lower bound above the upper bound is
    not a valid colour range.
    """
    if not spike_data or not spike_epoch_list:
        logger.warning("No spike alignment data to plot")
        return

    layer_names = list(spike_data.keys())
    data_matrix = np.array([spike_data[layer] for layer in layer_names])

    data_min = np.min(data_matrix)
    data_max = np.max(data_matrix)

    # Preferred lower bound of the colour scale; only usable while it stays
    # below the largest value actually present.
    color_floor = 0.8
    color_min = color_floor if data_max > color_floor else data_min

    plt.figure(figsize=(14, 8))

    colors = ['#FFB6C1', '#FFA0B4', '#E6E6FA', '#F0F8FF', '#E0F6FF', '#D1E7FF', '#B0E0E6', '#87CEEB', '#6495ED', '#4169E1']
    custom_cmap = LinearSegmentedColormap.from_list('LightRedToGradualBlue', colors, N=256)

    sns.heatmap(data_matrix,
                xticklabels=spike_epoch_list,
                yticklabels=layer_names,
                annot=True,
                fmt='.3f',
                cmap=custom_cmap,
                vmin=color_min, vmax=data_max,
                cbar_kws={'label': 'Hamming Similarity (Best Sample)'},
                annot_kws={'fontsize': 11, 'fontweight': 'bold'})

    plt.xlabel('Epoch', fontsize=14)
    plt.ylabel('Layer', fontsize=14)
    plt.title('Spike Alignment Between Forward and Backward Networks',
              fontsize=16, fontweight='bold')
    plt.tight_layout()

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"Spike alignment plot saved to {save_path}")
    # Report both ranges: the colour scale is not necessarily the data range.
    logger.info(f"Spike alignment data range: [{data_min:.3f}, {data_max:.3f}], color range: [{color_min:.3f}, {data_max:.3f}]")


def extract_layer_features(model, dataloader, device, max_samples=500, num_classes=10, num_timesteps=4):
    """
    Collect per-layer LIF outputs of the forward and backward networks for
    t-SNE visualisation.

    Forward hooks are registered on the ``forward_lif`` / ``backward_lif``
    sub-modules of the known encoder layers. Each batch is run through
    ``model(...)`` and, when the model has a ``reverse`` method, through
    ``model.reverse(...)`` with one-hot labels repeated over time.

    Args:
        model: Trained model.
        dataloader: Iterable yielding ``(data, targets)`` batches.
        device: Device the batches are moved to.
        max_samples: Stop starting new batches once this many samples were
            processed (keeps memory bounded; the last batch may overshoot).
        num_classes: Width of the one-hot label fed to ``model.reverse``.
        num_timesteps: Number of time steps the one-hot label is repeated for.

    Returns:
        forward_features: {layer_name: features_array} from the forward LIFs
        backward_features: {layer_name: features_array} from the backward LIFs
        targets: numpy array with the labels of the processed batches

        ``({}, {}, empty array)`` is returned when no hook could be registered.

    The model's train/eval mode is restored and all hooks are removed before
    returning, including when the forward pass raises.
    """
    forward_features = {}
    backward_features = {}

    def _make_hook(store, name):
        # One factory serves both directions; only the target dict differs.
        def hook(module, inputs, output):
            # [T, N, C, H, W] -> [N, C, H, W] (average over time)
            # NOTE(review): only 5-D outputs are time-averaged; a 3-D
            # [T, N, features] output would be flattened along its first axis
            # instead — confirm the LIF output shapes of the linear layers.
            feat = output.mean(dim=0) if output.dim() == 5 else output
            if feat.dim() != 2:
                feat = feat.reshape(feat.shape[0], -1)
            store.setdefault(name, []).append(feat.detach().cpu())
        return hook

    if hasattr(model, 'enc5'):  # 5-layer architectures (CNN, CNN_Pool, etc.)
        layer_names = ['enc1', 'enc2', 'enc3', 'enc4', 'enc5']
    elif hasattr(model, 'enc3') and hasattr(model, 'fc'):  # Updated WideShallow (3-layer)
        layer_names = ['enc1', 'enc2', 'enc3', 'fc']
    else:  # Other architectures
        layer_names = ['enc1', 'enc2', 'fc']

    hooks = []
    for position, layer_name in enumerate(layer_names, start=1):
        layer = getattr(model, layer_name, None)
        if layer is None:
            continue
        for lif_attr, store in (('forward_lif', forward_features),
                                ('backward_lif', backward_features)):
            if hasattr(layer, lif_attr):
                lif = getattr(layer, lif_attr)
                hooks.append(lif.register_forward_hook(_make_hook(store, f'Layer_{position}')))

    if not hooks:
        # Nothing was touched yet, so the model keeps its current mode.
        logger.warning("No hooks registered - model architecture may not match expected structure")
        return {}, {}, np.array([])

    was_training = model.training
    model.eval()
    all_targets = []
    sample_count = 0

    try:
        with torch.no_grad():
            for data, targets in dataloader:
                if sample_count >= max_samples:
                    break

                data = data.to(device)
                targets = targets.to(device)

                # Forward pass (hooks on forward_lif fire here)
                functional.reset_net(model)
                _ = model(data, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False)

                # Backward pass (hooks on backward_lif fire here)
                functional.reset_net(model)
                if hasattr(model, 'reverse'):
                    batch_size = data.shape[0]
                    one_hot_labels = torch.zeros(batch_size, num_classes, device=device)
                    one_hot_labels.scatter_(1, targets.unsqueeze(1), 1)
                    one_hot_expanded = one_hot_labels.unsqueeze(0).repeat(num_timesteps, 1, 1)
                    try:
                        _ = model.reverse(one_hot_expanded, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False)
                    except Exception as e:
                        # Backward features are best-effort; keep the forward ones.
                        logger.warning(f"Failed to collect backward features: {e}")

                all_targets.append(targets.cpu())
                sample_count += data.shape[0]
    finally:
        # Always leave the model as it was found: no hooks, original mode,
        # cleared neuron state.
        for hook in hooks:
            hook.remove()
        model.train(was_training)
        functional.reset_net(model)

    final_forward_features = {
        name: torch.cat(chunks, dim=0).numpy()
        for name, chunks in forward_features.items() if chunks
    }
    final_backward_features = {
        name: torch.cat(chunks, dim=0).numpy()
        for name, chunks in backward_features.items() if chunks
    }
    all_targets = torch.cat(all_targets, dim=0).numpy() if all_targets else np.array([])

    return final_forward_features, final_backward_features, all_targets


def _embed_features_2d(features):
    """Project a [samples, features] array to 2-D (drop constant columns, PCA to <= 50 dims, t-SNE)."""
    if features.shape[1] > 50:
        # Constant columns carry no information; remove them before PCA.
        column_std = np.std(features, axis=0)
        features = features[:, column_std >= 1e-8]
        n_components = min(50, features.shape[1], features.shape[0] - 1)
        reduced = PCA(n_components=n_components, random_state=42).fit_transform(features)
    else:
        reduced = features

    perplexity = min(30, max(5, features.shape[0] // 4))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity,
                init='pca', learning_rate=200, max_iter=500)
    return tsne.fit_transform(reduced)


def plot_feature_alignment(features_by_epoch, targets_by_epoch, save_path, total_epochs=25):
    """
    Plot t-SNE embeddings of layer features across epochs and save the
    embedding data next to the image.

    Up to five epochs (evenly spaced over the recorded ones) form the columns;
    the rows are 'Layer_4'/'Layer_5' when present, otherwise the last two
    recorded layers. Each cell embeds the forward-network features of that
    layer, falling back to the backward-network features when no forward
    features were recorded.

    Args:
        features_by_epoch: {epoch: {'forward': {layer_name: features},
            'backward': {layer_name: features}}}
        targets_by_epoch: {epoch: targets}
        save_path: Path of the image to write. The embedding data is pickled
            to the same path with its extension replaced by ``_data.pkl``.
        total_epochs: Total number of training epochs (currently unused; kept
            for interface compatibility).

    The pickled dict stores the plotted embeddings under 'coordinates' and
    'metrics' (the keys ``regenerate_feature_alignment_from_data`` reads) and
    additionally under the matching 'forward_*' / 'backward_*' keys.
    """
    if not features_by_epoch:
        logger.warning("No feature alignment data to plot")
        return

    # Derive the data path from the extension only, so it can never equal
    # save_path (which would make the pickle overwrite the image).
    data_save_path = os.path.splitext(save_path)[0] + '_data.pkl'

    epochs_to_show = sorted(features_by_epoch.keys())
    if len(epochs_to_show) > 5:
        indices = np.linspace(0, len(epochs_to_show) - 1, 5, dtype=int)
        epochs_to_show = [epochs_to_show[i] for i in indices]

    first_epoch_data = features_by_epoch[epochs_to_show[0]]
    forward_layer_names = sorted(first_epoch_data.get('forward', {}).keys())
    backward_layer_names = sorted(first_epoch_data.get('backward', {}).keys())
    all_layer_names = sorted(set(forward_layer_names + backward_layer_names))

    if not all_layer_names:
        logger.warning("No layers found for feature alignment plotting")
        return

    logger.info(f"Found layers - Forward: {forward_layer_names}, Backward: {backward_layer_names}")

    layers_to_show = [layer for layer in all_layer_names if layer in ['Layer_4', 'Layer_5']]
    if not layers_to_show:
        layers_to_show = all_layer_names[-2:] if len(all_layer_names) >= 2 else all_layer_names

    logger.info(f"Displaying layers: {layers_to_show}")

    bright_colors = plt.cm.Set1(np.linspace(0, 1, 10))
    markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h']

    tsne_data = {
        'epochs': epochs_to_show,
        'layers': all_layer_names,
        'coordinates': {},          # {(epoch, layer): features_2d} as plotted
        'metrics': {},              # {(epoch, layer): {'silhouette': score, 'knn_acc': acc}} as plotted
        'forward_coordinates': {},  # {(epoch, layer): features_2d}
        'backward_coordinates': {}, # {(epoch, layer): features_2d}
        'targets': {},              # {epoch: targets}
        'forward_metrics': {},      # {(epoch, layer): {'silhouette': score, 'knn_acc': acc}}
        'backward_metrics': {}      # {(epoch, layer): {'silhouette': score, 'knn_acc': acc}}
    }

    # squeeze=False always yields a 2-D axes array, including for a 1x1 grid.
    fig, axes = plt.subplots(len(layers_to_show), len(epochs_to_show),
                             figsize=(4 * len(epochs_to_show), 3 * len(layers_to_show)),
                             squeeze=False)

    for row, layer_name in enumerate(layers_to_show):
        for col, epoch in enumerate(epochs_to_show):
            ax = axes[row, col]

            # Prefer forward features; fall back to backward ones.
            epoch_features = features_by_epoch.get(epoch, {})
            source, features = None, None
            for candidate in ('forward', 'backward'):
                layer_features = epoch_features.get(candidate, {})
                if layer_name in layer_features:
                    source, features = candidate, layer_features[layer_name]
                    break
            targets = targets_by_epoch.get(epoch)

            message = None
            if features is None or targets is None:
                message = 'No Data'
            elif features.shape[0] < 30:
                message = 'Insufficient Data'
            elif np.std(features) < 1e-6:
                message = 'Constant Features'
            else:
                try:
                    features_2d = _embed_features_2d(features)

                    sil_score = silhouette_score(features_2d, targets)
                    knn = KNeighborsClassifier(n_neighbors=5)
                    knn.fit(features_2d, targets)
                    knn_acc = knn.score(features_2d, targets)

                    cell_key = (epoch, layer_name)
                    cell_metrics = {'silhouette': sil_score, 'knn_acc': knn_acc}
                    tsne_data['coordinates'][cell_key] = features_2d.copy()
                    tsne_data['metrics'][cell_key] = cell_metrics
                    tsne_data[f'{source}_coordinates'][cell_key] = features_2d.copy()
                    tsne_data[f'{source}_metrics'][cell_key] = dict(cell_metrics)
                    tsne_data['targets'][epoch] = targets.copy()

                    for class_idx in range(10):
                        mask = targets == class_idx
                        if np.sum(mask) > 0:
                            ax.scatter(features_2d[mask, 0], features_2d[mask, 1],
                                       c=[bright_colors[class_idx]], s=35, alpha=0.8,
                                       marker=markers[class_idx])
                except Exception as e:
                    # One failing cell must not abort the whole figure.
                    logger.warning(f"Feature alignment failed for {layer_name} at epoch {epoch}: {e}")
                    message = f'Error: {str(e)[:20]}...'

            if message is not None:
                ax.text(0.5, 0.5, message, ha='center', va='center',
                        transform=ax.transAxes, fontsize=10)

            # Cosmetics apply to every cell, including placeholder ones.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            if row == 0:
                if epoch == 0:
                    ax.set_title('Before training', fontsize=11, fontweight='bold')
                elif epoch == 1:
                    ax.set_title('After 1 epoch', fontsize=11, fontweight='bold')
                else:
                    ax.set_title(f'After {epoch} epochs', fontsize=11, fontweight='bold')

            if col == 0:
                # Label with the real layer, not the row index of the subset shown.
                ax.set_ylabel(layer_name.replace('_', ' '), fontsize=11, fontweight='bold')

    cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']
    legend_elements = [
        plt.Line2D([0], [0], marker=markers[i], color='w',
                   markerfacecolor=bright_colors[i], markersize=8,
                   label=cifar10_classes[i], linestyle='None')
        for i in range(10)
    ]

    fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(0.98, 0.5),
               fontsize=10, title='CIFAR-10 Classes', title_fontsize=11)

    fig.suptitle('Feature Alignments Across Layers',
                 fontsize=16, fontweight='bold', y=0.95)

    fig.tight_layout()
    fig.subplots_adjust(top=0.9, hspace=0.3, wspace=0.1, right=0.85)

    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    logger.info(f"Feature alignment plot saved to {save_path}")

    try:
        with open(data_save_path, 'wb') as f:
            pickle.dump(tsne_data, f)
        logger.info(f"t-SNE data saved to {data_save_path}")
    except Exception as e:
        logger.warning(f"Failed to save t-SNE data: {e}")

def regenerate_feature_alignment_from_data(data_path, output_path, style='default'):
    """
    Re-create the feature alignment figure from saved t-SNE data.

    Args:
        data_path: Path of the pickled t-SNE data.
        output_path: Path of the image to write.
        style: Visualisation style ('default', 'paper', 'colorful').
            Currently unused; every value produces the same figure.

    Returns:
        True when the figure was written, False when anything failed (the
        error is printed, not raised).
    """
    fig = None
    try:
        # NOTE: pickle.load can execute arbitrary code; only open data files
        # that this project wrote itself.
        with open(data_path, 'rb') as f:
            tsne_data = pickle.load(f)

        epochs_to_show = tsne_data['epochs']
        layer_names = tsne_data['layers']
        # Accept both the combined key and the per-direction key of the
        # saved schema.
        coordinates = tsne_data.get('coordinates') or tsne_data.get('forward_coordinates') or {}

        # squeeze=False always yields a 2-D axes array, including for a 1x1 grid.
        fig, axes = plt.subplots(len(layer_names), len(epochs_to_show),
                                 figsize=(4 * len(epochs_to_show), 3 * len(layer_names)),
                                 squeeze=False)

        bright_colors = plt.cm.Set1(np.linspace(0, 1, 10))
        markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h']

        for row, layer_name in enumerate(layer_names):
            for col, epoch in enumerate(epochs_to_show):
                ax = axes[row, col]

                if (epoch, layer_name) in coordinates:
                    features_2d = coordinates[(epoch, layer_name)]
                    targets = tsne_data['targets'][epoch]

                    for class_idx in range(10):
                        mask = targets == class_idx
                        if np.sum(mask) > 0:
                            ax.scatter(features_2d[mask, 0], features_2d[mask, 1],
                                       c=[bright_colors[class_idx]], s=35, alpha=0.8,
                                       marker=markers[class_idx])

                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

                if row == 0:
                    if epoch == 0:
                        ax.set_title('Before training', fontsize=11, fontweight='bold')
                    elif epoch == 1:
                        ax.set_title('After 1 epoch', fontsize=11, fontweight='bold')
                    else:
                        ax.set_title(f'After {epoch} epochs', fontsize=11, fontweight='bold')

                if col == 0:
                    ax.set_ylabel(f'Layer {row+1}', fontsize=11, fontweight='bold')

        cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                           'dog', 'frog', 'horse', 'ship', 'truck']
        legend_elements = [
            plt.Line2D([0], [0], marker=markers[i], color='w',
                       markerfacecolor=bright_colors[i], markersize=8,
                       label=cifar10_classes[i], linestyle='None')
            for i in range(10)
        ]

        fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(0.98, 0.5),
                   fontsize=10, title='CIFAR-10 Classes', title_fontsize=11)

        fig.suptitle('Feature Alignments Across Layers',
                     fontsize=16, fontweight='bold', y=0.95)

        fig.tight_layout()
        fig.subplots_adjust(top=0.9, hspace=0.3, wspace=0.1, right=0.85)

        fig.savefig(output_path, dpi=300, bbox_inches='tight')

        print(f"✅ 重新生成特征对齐图: {output_path}")
        return True

    except Exception as e:
        # Best-effort utility: report the failure and signal it via the return value.
        print(f"❌ 无法重新生成图片: {e}")
        return False

    finally:
        # Release the figure on success and on failure alike.
        if fig is not None:
            plt.close(fig)


def log_gradient_statistics(model, logger, epoch):
    """Write per-layer gradient norm/mean/std of the weights to ``logger``.

    Args:
        model: Object whose layers are looked up by attribute name.
        logger: Logger receiving the ``info`` messages.
        epoch: Epoch number shown in the header line.
    """
    def _summary(grad):
        return f"norm={grad.norm():.6f}, mean={grad.mean():.6f}, std={grad.std():.6f}"

    def _weight_grad(owner, attr):
        # Gradient of owner.<attr>.weight, or None if the sub-module is absent.
        if not hasattr(owner, attr):
            return None
        return getattr(owner, attr).weight.grad

    logger.info(f"=== Epoch {epoch} - Gradient Statistics ===")

    # The candidate layer names depend on which architecture the model exposes.
    if hasattr(model, 'enc5'):
        candidates = ('enc1', 'enc2', 'enc3', 'enc4', 'enc5', 'output')
    elif hasattr(model, 'enc3') and hasattr(model, 'fc'):
        candidates = ('enc1', 'enc2', 'enc3', 'fc', 'output')
    else:
        candidates = ('enc1', 'enc2', 'fc', 'output')

    for name in candidates:
        if not hasattr(model, name):
            continue
        layer = getattr(model, name)
        logger.info(f"  {name}:")

        # Forward, then backward weights
        for attr, label in (('forward_layer', 'forward_weight'),
                            ('backward_layer', 'backward_weight')):
            grad = _weight_grad(layer, attr)
            if grad is None:
                logger.info(f"    {label}: No gradient")
            else:
                logger.info(f"    {label}: {_summary(grad)}")

        # Readout weights are only reported for layers that enable a readout
        if not getattr(layer, 'use_readout', False):
            continue

        conv_grad = _weight_grad(layer, 'forward_readout_conv')
        if conv_grad is not None:
            logger.info(f"    readout_conv: {_summary(conv_grad)}")
            continue

        fc_grad = _weight_grad(layer, 'forward_readout_fc')
        if fc_grad is not None:
            logger.info(f"    readout_fc: {_summary(fc_grad)}")
        else:
            logger.info(f"    readout: No gradient")


def log_activation_statistics(spike_activations, signals, logger, epoch):
    """Write per-layer forward and backward statistics for one epoch to ``logger``.

    Forward entries that are tensors and sit strictly between the first and
    last position of ``spike_activations`` are summarised (min, max, mean
    reported as firing rate, std, fraction of exact zeros). Tensors at the
    first/last position only have their shape logged; non-tensors have their
    type logged.

    If ``signals`` is truthy, its last element is dropped and every remaining
    tensor after index 0 is summarised the same way.

    Args:
        spike_activations: Sequence of per-layer forward outputs.
        signals: Sequence of backward outputs; a falsy value skips the section.
        logger: Object exposing ``info(message)``.
        epoch: Epoch number shown in the header line.
    """
    def summarize(tensor):
        # (min, max, mean, std, fraction of zeros) as plain Python floats
        return (
            tensor.min().item(),
            tensor.max().item(),
            tensor.mean().item(),
            tensor.std().item(),
            (tensor == 0).float().mean().item(),
        )

    logger.info(f"=== Epoch {epoch} - SPIKE Activation Statistics (LIF Sparsity) ===")

    # Forward section: only interior entries are treated as LIF outputs
    logger.info(f"  Forward Spike Activations (LIF outputs):")
    last_index = len(spike_activations) - 1
    for idx, layer_out in enumerate(spike_activations):
        if not isinstance(layer_out, torch.Tensor):
            logger.info(f"    Layer[{idx}]: Non-tensor ({type(layer_out)})")
            continue
        if not 0 < idx < last_index:
            logger.info(f"    Layer[{idx}] - Shape: {list(layer_out.shape)} (input/output, not LIF)")
            continue
        lo, hi, rate, spread, silent = summarize(layer_out)
        logger.info(f"    Spike[{idx}] - Shape: {list(layer_out.shape)}, Min: {lo:.1f}, Max: {hi:.1f}")
        logger.info(f"               Firing Rate: {rate:.4f}, Std: {spread:.4f}, Sparsity: {silent:.4f}")

    # Backward section: the final element is excluded (the original author labels it the target)
    if signals:
        logger.info(f"  Backward Readout Features:")
        for idx, feature in enumerate(signals[:-1]):
            if not isinstance(feature, torch.Tensor):
                logger.info(f"    Signal[{idx}]: Non-tensor ({type(feature)})")
                continue
            if idx <= 0:
                logger.info(f"    Signal[{idx}] - Shape: {list(feature.shape)} (input/target)")
                continue
            lo, hi, avg, spread, silent = summarize(feature)
            logger.info(f"    Readout[{idx}] - Shape: {list(feature.shape)}, Min: {lo:.4f}, Max: {hi:.4f}")
            logger.info(f"                 Mean: {avg:.4f}, Std: {spread:.4f}, Sparsity: {silent:.4f}")

    logger.info(f"==================================================================")


# Argument Parser
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse calls ``type`` on the raw string, and ``bool("False")`` is
    ``True`` because every non-empty string is truthy. With ``type=bool`` a
    flag could therefore never be switched off from the command line. This
    converter understands the usual spellings instead.

    Args:
        value: Raw token from the command line (or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, as it was under type=bool
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Shared settings for boolean options: "--flag true/false" is parsed properly,
# and a bare "--flag" (no value) means True.
_BOOL_FLAG = dict(type=_str2bool, nargs='?', const=True)

parser = argparse.ArgumentParser(description='Spiking CNN BiDistill Training')

# Method configuration (from yaml method section)
parser.add_argument('--method', type=str, default='BSD', help='Training method (BSD or BP)')
parser.add_argument('--architecture', type=str, default='cnn_pool', help='Model architecture (only cnn_pool supported)')
parser.add_argument('--task', type=str, default='classification', help='Task type')
parser.add_argument('--fw_bn', type=int, default=2, help='Forward batch normalization')
parser.add_argument('--bw_bn', type=int, default=2, help='Backward batch normalization')
parser.add_argument('--bias_init', type=str, default='zero', help='Bias initialization')
parser.add_argument('--bn_affine', type=int, default=1, help='Batch normalization affine')

# Dataset configuration (from yaml dataset section)
parser.add_argument('--dataset', type=str, default='CIFAR10', help='Dataset name')
parser.add_argument('--batchsize', type=int, default=128, help='Batch size')
parser.add_argument('--accumulation_steps', type=int, default=1, help='Number of gradient accumulation steps')
parser.add_argument('--num_chn', type=int, default=3, help='Number of input channels')

# Training configuration (from yaml training section)
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
parser.add_argument('--lr_F', type=float, default=0.001, help='Forward learning rate')
parser.add_argument('--lr_B', type=float, default=0.001, help='Backward learning rate')
parser.add_argument('--wd_F', type=float, default=0.000001, help='Forward weight decay')
parser.add_argument('--wd_B', type=float, default=0.000001, help='Backward weight decay')
parser.add_argument('--mmt_F', type=float, default=0.8, help='Forward momentum')
parser.add_argument('--mmt_B', type=float, default=0.8, help='Backward momentum')
parser.add_argument('--warmup', type=int, default=30, help='Warmup steps')
parser.add_argument('--tmax', type=int, default=70, help='T_max for CosineAnnealingLR')
parser.add_argument('--eta_min', type=float, default=0, help='Minimum learning rate')
parser.add_argument('--optimizer', type=str, default='AdamW', help='Optimizer type (SGD or AdamW)')
parser.add_argument('--beta1', type=float, default=0.9, help='Adam beta1 parameter')
parser.add_argument('--beta2', type=float, default=0.95, help='Adam beta2 parameter')
parser.add_argument('--eps', type=float, default=1e-8, help='Adam epsilon parameter')
parser.add_argument('--GradC', type=int, default=0, help='Gradient centralization')
parser.add_argument('--grad_clip_F', type=float, default=0.3, help='Gradient clipping for forward pass')
parser.add_argument('--grad_clip_B', type=float, default=0.3, help='Gradient clipping for backward pass')
parser.add_argument('--loss_scale_C', type=float, default=0, help='Loss scale C')
parser.add_argument('--loss_scale_ssl', type=float, default=0, help='Loss scale SSL')
parser.add_argument('--filter_target', type=float, default=0.0, help='Filter target')
parser.add_argument('--seed', type=int, default=12042, help='Random seed')
parser.add_argument('--patience', type=int, default=100, help='Early stopping patience')

# Three-Term Loss Configuration (from yaml training section)
parser.add_argument('--use_three_term_loss', default=True, help='Use three-term loss', **_BOOL_FLAG)
parser.add_argument('--three_term_inv_weight', type=float, default=25.0, help='Three-term invariance weight')
parser.add_argument('--three_term_var_weight', type=float, default=0.0, help='Three-term variance weight')
parser.add_argument('--three_term_cov_weight', type=float, default=0.0, help='Three-term covariance weight')
parser.add_argument('--three_term_gamma', type=float, default=1.0, help='Three-term variance threshold')
parser.add_argument('--three_term_infonce_temperature', type=float, default=0.1, help='InfoNCE temperature')
parser.add_argument('--three_term_mse_start_ratio', type=float, default=0.0, help='Initial MSE ratio')
parser.add_argument('--three_term_mse_end_ratio', type=float, default=0.0, help='Final MSE ratio')
parser.add_argument('--three_term_use_relaxed_contrastive', default=True, help='Use ReCo loss', **_BOOL_FLAG)
parser.add_argument('--three_term_reco_lambda', type=float, default=0.6, help='ReCo lambda parameter')

# Wandb parameters (from yaml training section)
parser.add_argument('--use_wandb', default=False, help='Use Weights & Biases for logging', **_BOOL_FLAG)
parser.add_argument('--wandb_project', type=str, default='bidistill-spiking-cnn', help='Wandb project name')
parser.add_argument('--wandb_run_name', type=str, default=None, help='Wandb run name')
parser.add_argument('--wandb_tags', type=str, nargs='*', default=['cifar10', 'spiking_cnn', 'reco', 'no readout', 'wide'], help='Wandb tags')

# Data augmentation (from yaml augmentation section)
parser.add_argument('--use_augment', default=True, help='Enable data augmentation', **_BOOL_FLAG)
parser.add_argument('--augment_magnitude', type=int, default=4, help='Augmentation magnitude')

# Label encoding (from yaml label_encoding section)
parser.add_argument('--use_label_encoding', default=False, help='Enable label encoding', **_BOOL_FLAG)
parser.add_argument('--encoding_dim', type=int, default=128, help='Label encoding dimension')
parser.add_argument('--encoding_sparsity', type=float, default=0.9, help='Label encoding sparsity')

# Spiking parameters (from yaml spiking section)
parser.add_argument('--time_steps', type=int, default=4, help='Number of time steps')
parser.add_argument('--tau', type=float, default=2.0, help='LIF time constant')
parser.add_argument('--v_threshold_forward', type=float, default=0.8, help='Forward voltage threshold')
parser.add_argument('--v_threshold_backward', type=float, default=0.8, help='Backward voltage threshold')
parser.add_argument('--atan_alpha', type=float, default=2.0, help='Atan surrogate gradient alpha')
parser.add_argument('--use_prelif_for_loss', default=True, help='Use pre-LIF features for loss', **_BOOL_FLAG)
parser.add_argument('--backend', type=str, default='cupy', help='SpikingJelly backend')
parser.add_argument('--use_spike_readout', default=False, help='Use spike readout system', **_BOOL_FLAG)
parser.add_argument('--readout_expand_factor', type=float, default=0.0625, help='Readout expand factor')

# Noise training (from yaml noise_training section)
parser.add_argument('--use_noise_training', default=False, help='Enable noise training', **_BOOL_FLAG)
parser.add_argument('--noise_mean', type=float, default=0.0, help='Noise mean')
parser.add_argument('--noise_std', type=float, default=0.05, help='Noise std')

args = parser.parse_args()

# Per-dataset (names, mean, std) used to undo/redo input normalization when
# noise training is enabled. Datasets not listed here get (None, None).
# NOTE(review): the CIFAR entries carry the same values as the
# TinyImageNet/STL10 entries — confirm they match what the dataloader applies.
_DATASET_NORMALIZATION = (
    (('CIFAR10', 'CIFAR100'),
     [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    (('SVHN',),
     [0.4376821, 0.4437697, 0.47280442], [0.19803012, 0.20101562, 0.19703614]),
    (('MNIST', 'FashionMNIST', 'MNIST_CNN', 'FashionMNIST_CNN'),
     [0.1307], [0.3081]),
    (('TinyImageNet', 'STL10', 'STL10_cls'),
     [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
)

# First row whose name tuple contains the selected dataset wins
normalize_mean, normalize_std = next(
    ((mean, std) for names, mean, std in _DATASET_NORMALIZATION if args.dataset in names),
    (None, None),
)

# Expose the statistics on args so the noise-injection code can read them
args.normalize_mean = normalize_mean
args.normalize_std = normalize_std

# Alignment histories, keyed by layer name; each value is a list that grows
# as alignment values are appended.
weight_alignment_data, spike_alignment_data = defaultdict(list), defaultdict(list)
# Epoch indices recorded alongside the two histories above
epoch_list, spike_epoch_list = [], []

# Containers for feature snapshots used by the t-SNE visualization:
#   feature_alignment_data: {epoch: {layer_name: features}}
#   feature_targets_data:   {epoch: targets}
#   feature_epoch_list:     epochs at which snapshots were taken, e.g. [0, 5, 10, ...]
feature_alignment_data, feature_targets_data = {}, {}
feature_epoch_list = []

# Set up logging with a dataset-specific directory structure (needs parsed args).
# A single timestamp is taken so the log/checkpoint directories and the log
# file name always refer to the same instant.
_run_started = datetime.now()
current_datetime = _run_started.strftime("%Y-%m-%d_%H-%M-%S")
log_dir = f'logs/spiking_cnn_{args.dataset.lower()}/{current_datetime}'
checkpoint_dir = f'checkpoints/spiking_cnn_{args.dataset.lower()}/{current_datetime}'
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
log_file = os.path.join(log_dir, f'spiking_bsd_cnn_readout_{_run_started.strftime("%Y%m%d_%H%M%S")}.log')

# Disable root logger's default console handler to prevent debug info in terminal
logging.getLogger().handlers.clear()
logging.getLogger().setLevel(logging.WARNING)  # Only show warnings/errors from root logger

# File logger for detailed logs
file_logger = logging.getLogger('file_logger')
file_logger.setLevel(logging.INFO)
file_logger.propagate = False  # Don't propagate to root logger
# Explicit UTF-8: the script logs emoji and arrows, which cannot be encoded
# when the platform's default file encoding is not UTF-8.
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
file_logger.addHandler(file_handler)

# Console logger for important messages only
console_logger = logging.getLogger('console_logger')
console_logger.setLevel(logging.INFO)
console_logger.propagate = False  # Don't propagate to root logger
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(logging.Formatter('%(message)s'))
console_logger.addHandler(console_handler)

logger = file_logger  # Use file logger for detailed logs
info = console_logger.info  # Use console logger for important info

# Fill in optional settings that may be absent from args
if not hasattr(args, 'use_prelif_for_loss'):
    args.use_prelif_for_loss = False  # pre-LIF features are not used unless requested
if not hasattr(args, 'temporal_aggregation'):
    args.temporal_aggregation = 'mean'  # 'mean' or 'last'

# Pick CUDA when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device

# Build the train/test loaders from the parsed configuration
train_loader, test_loader = get_dataset(args)

# Number of target classes for every supported dataset
_DATASET_NUM_CLASSES = {
    "MNIST": 10,
    "CIFAR10": 10,
    "SVHN": 10,
    "FashionMNIST": 10,
    "FashionMNIST_CNN": 10,
    "MNIST_CNN": 10,
    "STL10_cls": 10,
    "CIFAR100": 100,
    "TinyImageNet": 200,
}

# Anything outside the table is rejected up front
if args.dataset not in _DATASET_NUM_CLASSES:
    raise ValueError(f"Unsupported dataset: {args.dataset}")
num_classes = _DATASET_NUM_CLASSES[args.dataset]

# Label-encoding setup. Without label encoding the output width is the class
# count; with it, the output width becomes the encoding dimension.
label_encodings = None
encoding_output_dim = num_classes

if not getattr(args, 'use_label_encoding', False):
    # One-hot path: the model outputs one unit per class
    args.num_classes = num_classes
else:
    logger.info("Creating Latin Square label encodings...")

    # Build one encoding per class on the training device
    label_encodings, position_owner = create_latin_square_label_encodings(
        num_classes=num_classes,
        timesteps=args.time_steps,
        encoding_dim=args.encoding_dim,
        sparsity=args.encoding_sparsity,
        device=args.device,
    )

    # The final layer must now emit encoding_dim values instead of num_classes
    encoding_output_dim = args.encoding_dim

    logger.info(f"Label encodings shape: {label_encodings.shape}")
    logger.info(f"Encoding dimension: {args.encoding_dim}, Sparsity: {args.encoding_sparsity}")

    # Sum of element-wise products over every unordered pair of classes;
    # the log message below states the expected total is 0.
    total_overlaps = sum(
        torch.sum(label_encodings[i] * label_encodings[j]).item()
        for i in range(num_classes)
        for j in range(i + 1, num_classes)
    )
    logger.info(f"Label encoding orthogonality check - Total overlaps: {total_overlaps} (should be 0)")

    # Keep the encodings on args so the training loop can index them by target
    args.label_encodings = label_encodings

    for config_line in (
        "=== LABEL ENCODING CONFIGURATION ===",
        f"Use label encoding: {args.use_label_encoding}",
        f"Encoding dimension: {args.encoding_dim}",
        f"Sparsity level: {args.encoding_sparsity}",
        f"Output dimension: {encoding_output_dim}",
        "=====================================",
    ):
        logger.info(config_line)

    args.num_classes = encoding_output_dim

# Build the model; 'cnn_pool' is the only architecture this script accepts
if args.architecture != "cnn_pool":
    raise ValueError(f"Unsupported architecture: {args.architecture}. Only 'cnn_pool' is supported.")
model = SpikingCNNModel_Pool(args).to(args.device)

# Optional custom weight initialization; skipped when the setting is absent
# or left at 'default'.
weight_init = getattr(args, 'weight_init', 'default')
if weight_init != 'default':
    init_mean, init_std = getattr(args, 'init_mean', 0.0), getattr(args, 'init_std', 0.02)
    initialize_model_weights(model, weight_init, init_mean, init_std)
    logger.info(f"Weight initialization: {weight_init} (mean={init_mean}, std={init_std})")

# Select the SpikingJelly backend now that the model exists.
# Anything other than an explicit 'cupy' request uses the torch backend.
if getattr(args, 'backend', None) != 'cupy':
    functional.set_backend(model, 'torch')
    info(f"✅ SpikingJelly backend set to: torch")
    logger.info(f"SpikingJelly backend set to: torch")
else:
    try:
        functional.set_backend(model, 'cupy')
        info(f"✅ SpikingJelly backend set to: cupy")
        logger.info(f"SpikingJelly backend successfully set to: cupy")
    except Exception as err:
        # Best effort: report the failure and carry on with the torch backend
        info(f"⚠️ Failed to set cupy backend, falling back to torch: {err}")
        logger.warning(f"Failed to set cupy backend: {err}")
        functional.set_backend(model, 'torch')

# Describe the readout system in the detailed log
for readout_line in (
    f"🔍 READOUT SYSTEM DEBUG:",
    f"  Mode: Training uses readout features for BSD loss, inference uses spikes",
    f"  Architecture: Each LIF → FC → BN → float features (for training only)",
):
    logger.info(readout_line)

# Report stride precompute results when the model exposes them
if hasattr(model, 'get_stride_precompute_results'):
    stride_results = model.get_stride_precompute_results()
    if not stride_results:
        logger.info(f"📏 STRIDE PRECOMPUTE: No shrinking applied (expand_factor={getattr(args, 'readout_expand_factor', 1)})")
    else:
        logger.info(f"📏 STRIDE PRECOMPUTE RESULTS (readout_expand_factor={getattr(args, 'readout_expand_factor', 1)}):")
        for result in stride_results:
            # Each result is read as a mapping with the keys used below
            status_msg = "✅ SAFE" if result['is_safe'] else f"⚠️  UNSAFE→fallback"
            expected = result['expected_size']
            logger.info(f"  {result['layer_name']}: target_stride={result['target_stride']} → safe_stride={result['safe_stride']} ({status_msg})")
            logger.info(f"      expected_spatial_size={expected}×{expected}, expand_factor={result['expand_factor']}")

# Report how many parameter entries each direction owns
logger.info(f"Forward params total: {len(model.forward_params)}")
logger.info(f"Backward params total: {len(model.backward_params)}")

# One optimizer per direction; AdamW when requested, SGD with momentum otherwise
if getattr(args, 'optimizer', None) == 'AdamW':
    _adamw_shared = dict(betas=(args.beta1, args.beta2), eps=args.eps)
    forward_optimizer = optim.AdamW(
        model.forward_params, lr=args.lr_F, weight_decay=args.wd_F, **_adamw_shared
    )
    backward_optimizer = optim.AdamW(
        model.backward_params, lr=args.lr_B, weight_decay=args.wd_B, **_adamw_shared
    )
else:
    forward_optimizer = optim.SGD(
        model.forward_params, lr=args.lr_F, momentum=args.mmt_F, weight_decay=args.wd_F
    )
    backward_optimizer = optim.SGD(
        model.backward_params, lr=args.lr_B, momentum=args.mmt_B, weight_decay=args.wd_B
    )

# Cosine schedules are only created when tmax is non-zero; with tmax == 0 the
# two scheduler names are left undefined by this section.
if args.tmax != 0:
    forward_scheduler, backward_scheduler = (
        CosineAnnealingLR(opt, T_max=args.tmax, eta_min=args.eta_min)
        for opt in (forward_optimizer, backward_optimizer)
    )

# Loss objects: the project's three-term criterion plus a plain cross-entropy
criterion = C2Loss_Classification_CNN_Test(args)
logger.info("Using Three-Term Loss with readout float features")
CELoss = nn.CrossEntropyLoss()

# args.T10 holds every class index (0 .. num_classes-1) as an int64 tensor on
# the training device; the original comment describes it as the fixed set of
# class prototypes for the test loop.
_T10_CLASS_COUNTS = {
    "MNIST": 10,
    "CIFAR10": 10,
    "SVHN": 10,
    "FashionMNIST": 10,
    "FashionMNIST_CNN": 10,
    "MNIST_CNN": 10,
    "STL10_cls": 10,
    "CIFAR100": 100,
    "TinyImageNet": 200,
}
if args.dataset in _T10_CLASS_COUNTS:
    num_classes = _T10_CLASS_COUNTS[args.dataset]
    args.T10 = torch.arange(num_classes, dtype=torch.long).to(device)

# Helper function to create batch one-hot encoding
def create_batch_onehot(targets, num_classes, device):
    """
    Build a float32 one-hot matrix for a batch of class indices.

    Args:
        targets: 1-D tensor of class indices, length n
        num_classes: width of the one-hot matrix
        device: device on which the result is allocated
    Returns:
        tensor of shape [n, num_classes] with a single 1.0 per row
    """
    n_rows = len(targets)
    # Move the labels to the output device and shape them as a column so that
    # scatter_ writes exactly one entry per row.
    class_column = targets.to(device).unsqueeze(1)
    return torch.zeros(
        n_rows, num_classes, dtype=torch.float32, device=device
    ).scatter_(1, class_column, 1.0)

# Training-loop bookkeeping
best_test_loss = np.inf   # lowest test loss seen so far
best_test_acc = 0.0       # best test accuracy, used for early stopping
patience_counter = 0      # early stopping counter
args.train_steps = 0

# Checkpoints go to the directory created during logging setup
ckpt_dir = checkpoint_dir

# Wall-clock start of training
training_start_time = datetime.now()

# Weights & Biases: initialise when requested and importable, otherwise warn
# and turn the option off.
if args.use_wandb:
    if WANDB_AVAILABLE:
        # Derive a run name from the configuration when none was supplied
        if args.wandb_run_name is None:
            start_stamp = training_start_time.strftime('%m%d_%H%M')
            args.wandb_run_name = f"{args.method}_{args.architecture}_{args.dataset}_{start_stamp}"

        # Merge the always-present tags with the user-supplied ones, dropping duplicates
        default_tags = [args.method, args.architecture, args.dataset, "readout"]
        all_tags = list(set(default_tags + args.wandb_tags))

        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            tags=all_tags,
            config=vars(args),
        )
        print(f"📊 Wandb initialized: {wandb.run.url}")
    else:
        print("⚠️ Warning: Wandb requested but not installed. Install with: pip install wandb")
        args.use_wandb = False

# Console banner summarising the run that is about to start
print(f"\n🚀 Starting training with {args.method} method on {args.dataset}")
print(f"📊 Architecture: {args.architecture}")
readout_status = "Enabled" if getattr(args, 'use_spike_readout', True) else "Disabled"
print(f"🔥 Using spike readout system: {readout_status}")
print(f"⏰ Training started at: {training_start_time.strftime('%Y-%m-%d %H:%M:%S')}")
if args.use_wandb and WANDB_AVAILABLE:
    print(f"📈 Wandb tracking: Enabled")
print("-" * 60)

# Detailed configuration dump for the log file
logger.info(f"=== SPIKING BSD CNN WITH READOUT CONFIGURATION ===")
logger.info(f"Forward input: Real data samples [batch, 3, 32, 32]")
logger.info(f"Backward input: Batch one-hot labels (biological plausibility)")
# Report the configured readout state (same value as the console banner)
# rather than a hard-coded "Enabled".
logger.info(f"Readout system: {readout_status}")
logger.info(f"  Training: LIF spikes → FC+BN → float features → Three-Term Loss")
logger.info(f"  Inference: Pure spiking network (readout branches ignored)")

# Log surrogate function information
surrogate_name = getattr(args, 'surrogate_function', 'atan')
logger.info(f"Surrogate function: {surrogate_name}")
if surrogate_name == 's2nn' and hasattr(args, 'surrogate_params'):
    params = args.surrogate_params
    alpha = params.get('s2nn_alpha', 4.0)
    beta = params.get('s2nn_beta', 1.0)
    logger.info(f"  S2NN parameters: alpha={alpha} (negative part), beta={beta} (positive part)")

logger.info(f"Time steps: {args.time_steps}, Tau: {args.tau}")
logger.info(f"V_threshold (F/B): {args.v_threshold_forward}/{args.v_threshold_backward}")
logger.info(f"Filter target: {args.filter_target}")
logger.info(f"Loss scales - C: {args.loss_scale_C}, SSL: {args.loss_scale_ssl}")

# Log gradient accumulation settings
if args.accumulation_steps > 1:
    # Effective batch size = per-step batch size times the number of accumulated steps
    effective_batch_size = args.batchsize * args.accumulation_steps
    logger.info(f"Gradient accumulation: Enabled (steps={args.accumulation_steps})")
    logger.info(f"  Batch size: {args.batchsize} → Effective batch size: {effective_batch_size}")
else:
    logger.info(f"Gradient accumulation: Disabled (batch size: {args.batchsize})")

# Log data augmentation settings
augment_status = f"Data augmentation: {'Enabled' if getattr(args, 'use_augment', False) else 'Disabled'}"
if getattr(args, 'use_augment', False):
    augment_status += f" (RandAugment magnitude: {getattr(args, 'augment_magnitude', 4)})"
logger.info(augment_status)

# Log label encoding settings
label_encoding_status = f"Label encoding: {'Enabled' if getattr(args, 'use_label_encoding', False) else 'Disabled (using one-hot)'}"
if getattr(args, 'use_label_encoding', False):
    label_encoding_status += f" (Dim: {args.encoding_dim}, Sparsity: {args.encoding_sparsity})"
logger.info(label_encoding_status)

# Log noise training settings
noise_training_status = f"Noise training: {'Enabled' if getattr(args, 'use_noise_training', False) else 'Disabled'}"
if getattr(args, 'use_noise_training', False):
    noise_training_status += f" (mean={getattr(args, 'noise_mean', 0.0)}, std={getattr(args, 'noise_std', 0.05)})"
    noise_training_status += " [Applied to both training and validation]"
logger.info(noise_training_status)

logger.info(f"==================================================")

for args.epoch in range(args.epochs):
    # Training loop
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    
    # Log detailed statistics at the beginning of each epoch (including epoch 0)
    if args.epoch == 0:
        logger.info(f"=== Epoch {args.epoch} - Initial Statistics ===")
        logger.info(f"Logging initial statistics before any training...")
        
    #     # Collect initial features (before training)
    #     logger.info(f"Collecting initial feature alignment data...")
    #     initial_forward_features, initial_backward_features, initial_targets = extract_layer_features(model, test_loader, device, max_samples=300)
    #     if initial_forward_features or initial_backward_features:
    #         feature_alignment_data[0] = {
    #             'forward': initial_forward_features,
    #             'backward': initial_backward_features
    #         }
    #         feature_targets_data[0] = initial_targets
    #         feature_epoch_list.append(0)
    #         logger.info(f"Initial features collected - Forward: {len(initial_forward_features)} layers, Backward: {len(initial_backward_features)} layers")
    #     else:
    #         logger.warning("Failed to collect initial features")
    # else:
    #     logger.info(f"=== Epoch {args.epoch} - Pre-Epoch Statistics ===")
    
    # Collect statistics from first batch
    first_batch_activations = None
    first_batch_signals = None

    # Gradient accumulation counter
    accumulation_counter = 0

    # Progress bar for training
    pbar = tqdm(train_loader, desc=f"Epoch {args.epoch}/{args.epochs-1}",
                leave=False, dynamic_ncols=True)

    for batch_idx, (data, target) in enumerate(pbar):
        # Reset network state
        functional.reset_net(model)

        # Apply Gaussian noise to training data if enabled
        if getattr(args, 'use_noise_training', False):
            data_fw = add_gaussian_noise(
                data.to(args.device),
                mean=args.noise_mean,
                std=args.noise_std,
                normalize_mean=args.normalize_mean,
                normalize_std=args.normalize_std
            )
        else:
            data_fw = data

        # Create backward network input: label encoding or one-hot
        if getattr(args, 'use_label_encoding', False):
            # Use label encoding: select encodings for batch targets
            batch_label_encodings = args.label_encodings[target]  # [N, T, L]
            # Convert from [N, T, L] to [T, N, L] for time-first format
            data_bw = batch_label_encodings.permute(1, 0, 2).to(args.device)
        else:
            # Create one-hot encoding for current batch targets
            data_bw = create_batch_onehot(target, num_classes, args.device)
        
        if args.method == "BSD":
            # ====================================================================
            # FORWARD PASS with LOCAL LEARNING (BSD Algorithm)
            # ====================================================================
            # CRITICAL: detach_grad=True enforces LOCALITY in learning
            # - Stops gradient flow between layers during forward propagation
            # - Each layer learns independently using only local signals
            # - This is a core principle of Bidirectional Spike Distillation (BSD)
            # - Enables biologically plausible learning without global backprop
            # ====================================================================
            if batch_idx == 0:
                if args.use_prelif_for_loss:
                    activations = model(data_fw.to(args.device), detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=True)
                    first_batch_activations = None
                else:
                    activations = model(data_fw.to(args.device), detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False)
                    first_batch_activations = None
            else:
                if args.use_prelif_for_loss:
                    activations = model(data_fw.to(args.device), detach_grad=True, use_prelif_for_loss=True)
                else:
                    activations = model(data_fw.to(args.device), detach_grad=True, use_prelif_for_loss=False)


            # ====================================================================
            # BACKWARD PASS with LOCAL LEARNING (BSD Algorithm)
            # ====================================================================
            # CRITICAL: detach_grad=True enforces LOCALITY in learning
            # - Stops gradient flow between layers during backward propagation
            # - Backward network learns independently from forward network
            # - Creates symmetric bidirectional learning without weight transport
            # - This is essential for biologically plausible credit assignment
            # ====================================================================
            if batch_idx == 0:
                if args.use_prelif_for_loss:
                    signals = model.reverse(data_bw.to(args.device), detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=True)
                    first_batch_signals = None
                else:
                    signals = model.reverse(data_bw.to(args.device), detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False)
                    first_batch_signals = None
            else:
                if args.use_prelif_for_loss:
                    signals = model.reverse(data_bw.to(args.device), detach_grad=True, use_prelif_for_loss=True)
                else:
                    signals = model.reverse(data_bw.to(args.device), detach_grad=True, use_prelif_for_loss=False)
            
            # Process tensors for BSD loss
            # Prepare the per-layer features handed to the loss. Forward
            # activations and backward signals receive identical treatment:
            #   * None entries are dropped (readout disabled for that layer)
            #   * tensors with >= 3 dims are treated as time-first [T, N, ...]:
            #       - label encoding: swap to [N, T, ...] (time axis kept)
            #       - one-hot: collapse time with the last step or the mean
            #   * everything else is forwarded unchanged
            activations_for_loss = []
            signals_for_loss = []
            for raw_features, prepared in ((activations, activations_for_loss),
                                           (signals, signals_for_loss)):
                for feat in raw_features:
                    if feat is None:
                        continue
                    if isinstance(feat, torch.Tensor) and feat.dim() >= 3:
                        if getattr(args, 'use_label_encoding', False):
                            # [T, N, ...] → [N, T, ...]
                            feat = feat.permute(1, 0, *range(2, feat.dim()))
                        elif args.temporal_aggregation == 'last':
                            feat = feat[-1]  # last time step only
                        else:
                            feat = feat.mean(dim=0)  # time average (default)
                    prepared.append(feat)
            
            
            # Describe the loss inputs for the first batch only, to keep logs short
            if batch_idx == 0:
                use_readout = getattr(args, 'use_spike_readout', False)
                system_label = 'readout system' if use_readout else 'spike-only system'
                logger.info(f"Using {system_label} for Three-Term Loss")
                for direction, feats in (("Forward", activations_for_loss),
                                         ("Backward", signals_for_loss)):
                    # Tensors are reported by shape, anything else by its type
                    described = [f.shape if isinstance(f, torch.Tensor) else type(f) for f in feats]
                    logger.info(f"{direction} features: {described}")

            # Three-Term Loss on the prepared forward/backward features.
            # `loss` is used for backprop; `loss_item` is accumulated for reporting.
            loss, loss_item = criterion(
                activations_for_loss,
                signals_for_loss,
                target.to(args.device),
                method="local",
                current_epoch=args.epoch,
                total_epochs=args.epochs,
            )
            
        elif args.method == "BP":
            # BP mode runs the forward network only; there are no backward signals.
            bp_inputs = data_fw.to(args.device)
            if batch_idx != 0:
                activations = model(bp_inputs, detach_grad=False, use_prelif_for_loss=args.use_prelif_for_loss)
            else:
                # First batch: explicitly disable spike statistics and pass the
                # pre-LIF flag as a strict boolean.
                activations = model(
                    bp_inputs,
                    detach_grad=False,
                    return_spikes_for_stats=False,
                    use_prelif_for_loss=bool(args.use_prelif_for_loss),
                )
                first_batch_activations = None
                first_batch_signals = None  # No backward signals in BP mode

            # The last entry is the network output; a 3-D tensor is taken to be
            # [T, N, classes] and averaged over time to [N, classes].
            final_output = activations[-1]
            if isinstance(final_output, torch.Tensor) and final_output.dim() == 3:
                final_output = final_output.mean(dim=0)
            loss = CELoss(final_output, target.to(args.device))
            loss_item = loss.item()
        
        # Linear loss warm-up over the first `args.warmup` optimizer steps.
        # The ramp follows the global step counter (args.train_steps), the same
        # quantity the guard tests, so the factor rises monotonically to 1.0.
        # Ramping on batch_idx instead restarts every epoch and, with gradient
        # accumulation, can exceed 1.0 while the guard is still true.
        if args.train_steps < args.warmup:
            loss *= (args.train_steps + 1) / args.warmup

        # Scale loss by accumulation steps for gradient accumulation
        loss = loss / args.accumulation_steps

        # Only zero gradients at the start of an accumulation cycle
        if accumulation_counter == 0:
            forward_optimizer.zero_grad()
            backward_optimizer.zero_grad()

        loss.backward()

        # One more micro-batch has contributed to the current cycle
        accumulation_counter += 1

        # First micro-batch of the loop: report how many parameters received
        # gradients and log detailed gradient statistics
        if batch_idx == 0 and accumulation_counter == 1:
            forward_grad_count = sum(1 for p in model.forward_params if p.grad is not None)
            backward_grad_count = sum(1 for p in model.backward_params if p.grad is not None)
            logger.info(f"Forward params with gradients: {forward_grad_count}/{len(model.forward_params)}")
            logger.info(f"Backward params with gradients: {backward_grad_count}/{len(model.backward_params)}")

            # Log detailed statistics for first batch after backward pass
            log_gradient_statistics(model, logger, args.epoch)
            # if first_batch_activations is not None:
            #     log_activation_statistics(first_batch_activations, first_batch_signals, logger, args.epoch)

        # Only perform the optimizer step once a full cycle has been accumulated
        if accumulation_counter == args.accumulation_steps:
            if args.GradC == 1:
                gradient_centralization(model)

            # Gradient clipping, separately per network; 0 disables clipping
            if args.grad_clip_F != 0:
                torch.nn.utils.clip_grad_norm_(model.forward_params, args.grad_clip_F)
            if args.grad_clip_B != 0:
                torch.nn.utils.clip_grad_norm_(model.backward_params, args.grad_clip_B)

            forward_optimizer.step()
            backward_optimizer.step()

            # Start a new accumulation cycle and count the completed optimizer step
            accumulation_counter = 0

            args.train_steps += 1

        # loss_item is the loss value before warm-up/accumulation scaling
        train_loss += loss_item

        # Update progress bar with current and running-average loss. The
        # effective batch size is shown only when accumulation is in use,
        # so the bar never displays a meaningless "Eff_BS=None".
        effective_batch_size = args.batchsize * args.accumulation_steps
        postfix = {
            'Loss': f'{loss_item:.4f}',
            'Avg_Loss': f'{train_loss/(batch_idx+1):.4f}',
        }
        if args.accumulation_steps > 1:
            postfix['Eff_BS'] = effective_batch_size
        pbar.set_postfix(postfix)
    
    # Advance both LR schedulers after the batch loop; tmax == 0 skips scheduling
    if args.tmax != 0:
        for _epoch_scheduler in (forward_scheduler, backward_scheduler):
            _epoch_scheduler.step()
    
    # Test evaluation after each epoch.
    # Accumulates sample-weighted loss and the number of correct predictions
    # over test_loader, then normalises both by the number of samples seen.
    model.eval()
    test_loss = 0
    test_acc = 0
    test_counter = 0

    logger.info(f"Evaluating on test set after epoch {args.epoch}...")

    use_label_encoding = getattr(args, 'use_label_encoding', False)

    with torch.no_grad():
        for data, target in test_loader:
            # Reset network state
            functional.reset_net(model)

            # Apply Gaussian noise to validation data if noise training is enabled
            if getattr(args, 'use_noise_training', False):
                data_fw = add_gaussian_noise(
                    data.to(args.device),
                    mean=args.noise_mean,
                    std=args.noise_std,
                    normalize_mean=args.normalize_mean,
                    normalize_std=args.normalize_std
                )
            else:
                data_fw = data

            # Backward network input: label encoding or one-hot (same as training).
            # `target` is indexed/encoded before being moved to the device.
            if use_label_encoding:
                batch_label_encodings = args.label_encodings[target]  # [N, T, L]
                # [N, T, L] → [T, N, L] for time-first format
                data_bw = batch_label_encodings.permute(1, 0, 2).to(args.device)
            else:
                data_bw = create_batch_onehot(target, num_classes, args.device)

            # Move the labels once; reused by the loss and accuracy computations
            target_dev = target.to(args.device)

            if args.method == "BSD":
                # detach_grad=True mirrors the local-learning forward/backward
                # passes used in training, so evaluation sees the same features.
                activations = model(data_fw.to(args.device), detach_grad=True, use_prelif_for_loss=args.use_prelif_for_loss)
                signals = model.reverse(data_bw.to(args.device), detach_grad=True, use_prelif_for_loss=args.use_prelif_for_loss)

                # Process tensors for the BSD loss exactly as training does:
                # drop None entries, keep the time axis for label encoding, and
                # otherwise reduce time according to args.temporal_aggregation
                # (evaluation previously always averaged, ignoring 'last').
                activations_for_loss = []
                signals_for_loss = []
                for raw_features, prepared in ((activations, activations_for_loss),
                                               (signals, signals_for_loss)):
                    for feat in raw_features:
                        if feat is None:
                            continue
                        if isinstance(feat, torch.Tensor) and feat.dim() >= 3:
                            if use_label_encoding:
                                # [T, N, ...] → [N, T, ...]
                                feat = feat.permute(1, 0, *range(2, feat.dim()))
                            elif args.temporal_aggregation == 'last':
                                feat = feat[-1]
                            else:
                                feat = feat.mean(dim=0)
                        prepared.append(feat)

                # Three-Term Loss; only the scalar loss_item is needed here
                loss, loss_item = criterion(activations_for_loss, signals_for_loss, target_dev, method="local",
                                           current_epoch=args.epoch, total_epochs=args.epochs)

                # Predictions come from the last prepared forward feature
                if use_label_encoding:
                    # [N, T, L] output compared with the label encodings
                    _, predictions = compute_cosine_similarity_accuracy(
                        activations_for_loss[-1],
                        target_dev,
                        args.label_encodings
                    )
                else:
                    predictions = torch.argmax(activations_for_loss[-1], dim=1)
                test_acc += (predictions == target_dev).sum().item()

            elif args.method == "BP":
                activations = model(data_fw.to(args.device), detach_grad=True, use_prelif_for_loss=args.use_prelif_for_loss)
                final_output = activations[-1]
                # Time average for final output: [T, N, classes] → [N, classes]
                if isinstance(final_output, torch.Tensor) and final_output.dim() == 3:
                    final_output = final_output.mean(dim=0)
                loss_item = CELoss(final_output, target_dev).item()

                # Accuracy depends on the encoding type
                if use_label_encoding:
                    # The un-averaged output is rearranged to [N, T, L]
                    _, predictions = compute_cosine_similarity_accuracy(
                        activations[-1].permute(1, 0, 2),
                        target_dev,
                        args.label_encodings
                    )
                else:
                    predictions = torch.argmax(final_output, dim=1)
                test_acc += (predictions == target_dev).sum().item()

            # Weight the batch loss by batch size so the epoch value is a per-sample mean
            test_counter += len(data)
            test_loss += loss_item * len(data)

    test_loss = test_loss / test_counter
    test_acc = test_acc / test_counter
    
    # Gradient norms for both parameter groups, computed after evaluation
    # (i.e. from whatever gradients are still stored on the parameters)
    forward_norm, forward_count = compute_gradient_norm(model.forward_params)
    backward_norm, backward_count = compute_gradient_norm(model.backward_params)
    grad_summary = (
        f"Epoch {args.epoch} - Gradient norms: "
        f"Forward={forward_norm:.6f} ({forward_count} params), "
        f"Backward={backward_norm:.6f} ({backward_count} params)"
    )
    logger.info(grad_summary)

    # Mean training loss per batch for this epoch
    train_avg_loss = train_loss / len(train_loader)

    # One concise line on the terminal
    print(
        f"Epoch {args.epoch:3d}/{args.epochs-1} | Train Loss: {train_avg_loss:.4f} | "
        f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}"
    )
    
    # # Record weight alignment every epoch
    # logger.info(f"Computing weight alignment for epoch {args.epoch}...")
    # weight_alignments = compute_weight_alignment(model)
    # epoch_list.append(args.epoch)
    # for layer_name, alignment in weight_alignments.items():
    #     weight_alignment_data[layer_name].append(alignment)
    
    # logger.info(f"Weight alignments: {weight_alignments}")
    
    # # Record feature alignment every 5 epochs (including epoch 0 already collected)
    # if args.epoch % 5 == 0 and args.epoch > 0:
    #     logger.info(f"Computing feature alignment for epoch {args.epoch}...")
    #     epoch_forward_features, epoch_backward_features, epoch_targets = extract_layer_features(model, test_loader, device, max_samples=300)
    #     if epoch_forward_features or epoch_backward_features:
    #         feature_alignment_data[args.epoch] = {
    #             'forward': epoch_forward_features,
    #             'backward': epoch_backward_features
    #         }
    #         feature_targets_data[args.epoch] = epoch_targets
    #         feature_epoch_list.append(args.epoch)
    #         logger.info(f"Feature alignment collected - Forward: {len(epoch_forward_features)} layers, Backward: {len(epoch_backward_features)} layers at epoch {args.epoch}")
    #     else:
    #         logger.warning(f"Failed to collect feature alignment at epoch {args.epoch}")
    
    # # Record spike alignment every 20 epochs (starting from epoch 20, skip epoch 0)
    # if args.epoch % 20 == 0 and args.epoch > 0:
    #     logger.info(f"Computing spike alignment for epoch {args.epoch}...")
    #     if first_batch_activations is not None and first_batch_signals is not None:
    #         spike_alignments = compute_spike_alignment(first_batch_activations, first_batch_signals)
    #         spike_epoch_list.append(args.epoch)
    #         for layer_name, alignment in spike_alignments.items():
    #             spike_alignment_data[layer_name].append(alignment)
    #         logger.info(f"Spike alignments: {spike_alignments}")
    #     elif first_batch_activations is not None and args.method == "BP":
    #         # For BP mode, we only have activations, no signals to compare
    #         logger.info(f"Spike alignment skipped for BP mode (no backward signals to compare)")
    #     else:
    #         logger.warning(f"No spike activation/signal data available for spike alignment at epoch {args.epoch}")
    
    # Full-precision epoch summary for the log file
    logger.info(f"Epoch {args.epoch} Summary:")
    for metric_label, metric_value in (("Train Loss", train_avg_loss),
                                       ("Test Loss", test_loss),
                                       ("Test Acc", test_acc)):
        logger.info(f"  {metric_label}: {metric_value:.6f}")

    # Mirror the same metrics to Wandb when it is both requested and installed
    if args.use_wandb and WANDB_AVAILABLE:
        wandb_log = {
            'epoch': args.epoch,
            'train/loss': train_avg_loss,
            'train/steps': args.train_steps,
            'test/loss': test_loss,
            'test/accuracy': test_acc,
            'gradients/forward_norm': forward_norm,
            'gradients/backward_norm': backward_norm,
        }

        # Learning rates are only reported when schedulers are active
        if args.tmax != 0:
            wandb_log.update({
                'train/lr_forward': forward_optimizer.param_groups[0]['lr'],
                'train/lr_backward': backward_optimizer.param_groups[0]['lr'],
            })

        wandb.log(wandb_log)
    
    # Early stopping and checkpoint logic
    improved = False
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_test_loss = test_loss
        best_model = copy.deepcopy(model)
        args.best_epoch = args.epoch
        patience_counter = 0
        improved = True
        
        # Save checkpoint
        ckpt_path = os.path.join(ckpt_dir, f'best_model_epoch_{args.epoch}_acc_{test_acc:.4f}.pth')
        torch.save({
            'epoch': args.epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_forward_state_dict': forward_optimizer.state_dict(),
            'optimizer_backward_state_dict': backward_optimizer.state_dict(),
            'best_test_acc': best_test_acc,
            'best_test_loss': best_test_loss,
            'args': args
        }, ckpt_path)
        
        logger.info(f"  🎉 NEW BEST! Test accuracy improved to {test_acc:.4f}")
        logger.info(f"  💾 Checkpoint saved: {ckpt_path}")
        print(f"     🎉 New best Test Acc: {test_acc:.4f} - Model saved!")
        
        # Log checkpoint info to Wandb
        if args.use_wandb and WANDB_AVAILABLE:
            wandb.run.summary.update({
                'best_checkpoint_path': ckpt_path,
                'best_model_epoch': args.epoch,
                'best_model_test_acc': test_acc,
                'best_model_test_loss': test_loss
            })
        
    else:
        patience_counter += 1
        logger.info(f"  ⏳ No improvement for {patience_counter}/{args.patience} epochs")
        
        if patience_counter >= args.patience:
            logger.info(f"  🛑 EARLY STOPPING: No improvement for {args.patience} consecutive epochs")
            print(f"Early stopping triggered after {patience_counter} epochs without improvement.")
            break

# Switch to the best snapshot for final reporting
model = best_model
model.eval()

# Push the best metrics to Wandb, both as a logged point and as run summary
if args.use_wandb and WANDB_AVAILABLE:
    final_metrics = {
        'test/loss': best_test_loss,
        'test/accuracy': best_test_acc,
        'best/epoch': args.best_epoch,
        'best/test_accuracy': best_test_acc,
        'best/test_loss': best_test_loss,
    }
    wandb.log(final_metrics)

    # "final_*" entries carry the best values, since the best model is what is reported
    run_summary = {
        'best_epoch': args.best_epoch,
        'best_test_acc': best_test_acc,
        'best_test_loss': best_test_loss,
        'final_test_acc': best_test_acc,
        'final_test_loss': best_test_loss,
        'total_train_steps': args.train_steps,
    }
    wandb.run.summary.update(run_summary)

# Write the best model under a fixed file name.
# NOTE(review): this uses `checkpoint_dir`, whereas the per-epoch checkpoints are
# written under `ckpt_dir` - confirm both names are defined earlier in the file.
checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
best_checkpoint = {
    'epoch': args.best_epoch,
    'model_state_dict': best_model.state_dict(),
    'test_accuracy': best_test_acc,
    'test_loss': best_test_loss,
    'args': args
}
torch.save(best_checkpoint, checkpoint_path)
logger.info(f"Best model checkpoint saved to {checkpoint_path}")

# Generate visualization plots.
# Each plot is produced only when its data and matching epoch list are both
# non-empty. The outcome is remembered so the final summary reports only the
# plots that were actually generated.
logger.info("Generating visualization plots...")

# Plot weight alignment
weight_plot_path = os.path.join(log_dir, 'weight_alignment.png')
weight_plot_saved = bool(weight_alignment_data and epoch_list)
if weight_plot_saved:
    plot_weight_alignment(weight_alignment_data, weight_plot_path)
else:
    logger.warning("No weight alignment data to plot")

# Plot spike alignment
spike_plot_path = os.path.join(log_dir, 'spike_alignment.png')
spike_plot_saved = bool(spike_alignment_data and spike_epoch_list)
if spike_plot_saved:
    plot_spike_alignment(spike_alignment_data, spike_plot_path)
else:
    logger.warning("No spike alignment data to plot")

# Plot dual network feature alignment
feature_plot_path = os.path.join(log_dir, 'dual_network_feature_alignment.png')
feature_plot_saved = bool(feature_alignment_data and feature_epoch_list)
if feature_plot_saved:
    # Imported lazily so the module is only required when there is data to plot
    from dual_network_tsne_visualization import create_dual_network_tsne_visualization
    create_dual_network_tsne_visualization(feature_alignment_data, feature_targets_data, feature_plot_path)
else:
    logger.warning("No feature alignment data to plot")

plot_report = (
    ("Weight alignment", weight_plot_path, weight_plot_saved),
    ("Spike alignment", spike_plot_path, spike_plot_saved),
    ("Feature alignment", feature_plot_path, feature_plot_saved),
)
any_plot_saved = any(saved for _, _, saved in plot_report)

# Terminal output (prominent display)
print("\n" + "="*60)
print("🎯 FINAL RESULTS")
print("="*60)
print(f"Best epoch: {args.best_epoch}")
print(f"Best Test Acc: {best_test_acc:.4f}")
print(f"Best Test Loss: {best_test_loss:.4f}")
print(f"Checkpoint saved to: {checkpoint_path}")
if any_plot_saved:
    print(f"Plots saved to: {log_dir}")
else:
    print("Plots: none generated (no alignment data collected)")
print("="*60)

# Describe the configuration that was actually run instead of fixed text
use_readout = getattr(args, 'use_spike_readout', False)
if args.method == "BSD":
    training_strategy = f"{'Readout' if use_readout else 'Spike'} features → BSD loss"
else:
    training_strategy = "Backpropagation (cross-entropy loss)"

# Detailed logging
logger.info(f"🎯 FINAL RESULTS:")
logger.info(f"  Readout system: {'Enabled' if use_readout else 'Disabled'}")
logger.info(f"  Best test accuracy: {best_test_acc:.6f} (epoch {args.best_epoch})")
logger.info(f"  Best test loss: {best_test_loss:.6f}")
logger.info(f"  Training strategy: {training_strategy}")
logger.info(f"  Inference strategy: Pure spiking network")
logger.info(f"  Early stopping patience: {args.patience}")
training_end_time = datetime.now()
training_duration = training_end_time - training_start_time
# Drop the fractional seconds from the printed duration
logger.info(f"  Training duration: {str(training_duration).split('.')[0]}")
logger.info(f"  Training completed at: {training_end_time.strftime('%Y-%m-%d %H:%M:%S')}")
logger.info(f"  Model checkpoint saved to: {checkpoint_path}")
for plot_label, plot_path, plot_saved in plot_report:
    if plot_saved:
        logger.info(f"  {plot_label} plot saved to: {plot_path}")
    else:
        logger.info(f"  {plot_label} plot not generated (no data)")
logger.info(f"  Training completed successfully!")

# Report the run URL and close the Wandb run (only when Wandb is in use)
if args.use_wandb and WANDB_AVAILABLE:
    run_url = wandb.run.url
    logger.info(f"  Wandb run URL: {run_url}")
    print(f"📈 Wandb run completed: {run_url}")
    wandb.finish()